home *** CD-ROM | disk | FTP | other *** search
/ Turnbull China Bikeride / Turnbull China Bikeride - Disc 1.iso / HENSA / MISC / PDP.ARC / Source_c_pa < prev    next >
Encoding:
Text File  |  1987-12-23  |  13.9 KB  |  578 lines

  1. /*
  2.  
  3.        This file is part of the PDP software package.
  4.          
  5.        Copyright 1987 by James L. McClelland and David E. Rumelhart.
  6.        
  7.        Please refer to licensing information in the file license.txt,
  8.        which is in the same directory with this source file and is
  9.        included here by reference.
  10. */
  11.  
  12.  
  13. /* file: pa.c
  14.  
  15.     Do the actual work for the pa program.
  16.     
  17.     First version implemented by Elliot Jaffe.
  18.     
  19.     Date of last revision:  8-12-87/JLM.
  20. */
  21.  
  22. #include "general.h"
  23. #include "pa.h"
  24. #include "variable.h"
  25. #include "weights.h"
  26. #include "patterns.h"
  27. #include "command.h"
  28. #include <math.h>
  29.  
  30. char   *Prompt = "pa: ";
  31. boolean System_Defined = FALSE;
  32. char   *Default_step_string = "epoch";
  33. boolean lflag = 1;
  34. boolean linear = 0;
  35. boolean    lt = 0;
  36. boolean cs = 0;
  37. boolean hebb = 0;
  38. int     epochno = 0;
  39. int     nepochs = 500;
  40. int     patno = 0;
  41. float    ndp = 0;
  42. float    nvl = 0;
  43. float    vcor = 0;
  44. float   tss = 0.0;
  45. float   pss = 0.0;
  46. float   ecrit = 0.0;
  47. float  *netinput = NULL;
  48. float  *output = NULL;
  49. float  *error = NULL;
  50. float  *input = NULL;
  51. float  *target = NULL;
  52. float    noise = 0;
  53. float   temp = 15.0;
  54. int    tallflag = 0;
  55.  
  56.  
  57. extern int read_weights();
  58. extern int write_weights();
  59.  
  60. float *
  61. readvec(pstr,len) char *pstr; int len; {
  62.     int j;
  63.     float *tvec;
  64.     char *str;
  65.     char tstr[60];
  66.     
  67.     if (pstr == NULL) {
  68.         tvec = (float *) emalloc((unsigned)(sizeof(float)*len));
  69.     for (j = 0; j < len; j++) {
  70.         tvec[j] = 0.0;
  71.     }
  72.     return(tvec);
  73.     }
  74.     sprintf(tstr,"give %selements:  ",pstr);
  75.     tvec = (float *) emalloc((unsigned)(sizeof(float)*len));
  76.     for (j = 0; j < len; j++) {
  77.     tvec[j] = 0.0;
  78.     }
  79.     for (j = 0; j <= len; j++) {
  80.         str = get_command(tstr);
  81.     if (str == NULL || strcmp(str,"end") == 0) {
  82.         if (j) return(tvec); else return (NULL);
  83.     }
  84.     if (strcmp("+",str) == 0) tvec[j] = 1.0;
  85.     else if (strcmp("-",str) == 0) tvec[j] = -1.0;
  86.     else if (strcmp(".",str) == 0) tvec[j] = 0.0;
  87.     else sscanf(str,"%f",&tvec[j]);
  88.     }
  89.     return(tvec);
  90. }
  91.  
  92. float *
  93. get_vec() {
  94.     char * str;
  95.     int j;
  96.     str = 
  97.       get_command("vector (iN for ipattern, tN for tpattern, E for enter): ");
  98.     if (str == NULL) return(NULL);
  99.     if(*str == 'i') {
  100.     if((patno = get_pattern_number(++str)) < 0) {
  101.         put_error("Invalid pattern specification.");
  102.         return(NULL);
  103.     }
  104.         return(ipattern[patno]);
  105.     }
  106.     else if(*str == 't') {
  107.     if((patno = get_pattern_number(++str)) < 0) {
  108.         put_error("Invalid pattern specification.");
  109.         return(NULL);
  110.     }
  111.         return(tpattern[patno]);
  112.     }
  113.     else return(readvec(" ",nunits));
  114. }
  115.  
  116. float
  117. dotprod(v1,v2,len) float *v1, *v2; int len; {
  118.     register int i;
  119.     double dp = 0;
  120.     double denom;
  121.     denom = (double) len;
  122.     if (denom == 0) return(0.0);
  123.     for (i = 0; i < len; i++,v1++,v2++) {
  124.         dp += (double) ((*v1)*(*v2));
  125.     }
  126.     dp /= denom;
  127.     return(dp);
  128. }
  129.  
  130. float
  131. sumsquares(v1,v2,len) float *v1, *v2; int len; {
  132.     register int i;
  133.     double ss = 0;
  134.  
  135.     for (i = 0; i < len; i++,v1++,v2++) {
  136.         ss += (double)((*v1 - *v2) * (*v1 - *v2));
  137.     }
  138.     return(ss);
  139. }
  140.  
  141. /* the following function computes the vector correlation, or the
  142.    cosine of the angle between v1 and v2 */
  143.  
  144. float
  145. veccor(v1,v2,len) float *v1, *v2; int len; {
  146.     register int i;
  147.     double denom;
  148.     double dp = 0.0;
  149.     double l1 = 0.0;
  150.     double l2 = 0.0;
  151.  
  152.     for (i = 0; i < len; i++,v1++,v2++) {
  153.         dp += (double) (*v1)*(*v2);
  154.         l1 += (double) (*v1)*(*v1);
  155.         l2 += (double) (*v2)*(*v2);
  156.     }
  157.     if (l1 == 0.0 || l2 == 0.0) return (0.0);
  158.     dp /= sqrt(l1*l2);
  159.     return(dp);
  160. }
  161.  
  162. float
  163. veclen(v,len) float *v; int len; {
  164.     int i;
  165.     double denom;
  166.     double vl = 0;
  167.     denom = (double) len;
  168.     if (denom == 0) {
  169.         return(0.0);
  170.     }
  171.     for (i = 0; i < len; i++,v++) {
  172.         vl += (*v)*(*v)/denom;
  173.     }
  174.     vl = sqrt((vl));
  175.     return(vl);
  176. }
  177.  
  178. distort(vect,pattern,len,amount) 
  179. float *vect;
  180. float *pattern;
  181. int len;
  182. float   amount;
  183. {
  184.     int    i;
  185.     float   rval,val;
  186.  
  187.     for (i = 0; i < len; i++) {
  188.     rval = (float) (1.0 - 2.0*rnd());
  189.     *vect++ = *pattern++ + rval*amount;
  190.     }
  191. }
  192.  
  193. init_system() {
  194.     int     strain (), ptrain (), tall (), get_unames(),
  195.             test_pattern (), reset_weights(),newstart();
  196.     int change_lrate();
  197.  
  198.     lrate = 2.0;
  199.     epsilon_menu = NOMENU;
  200.     (void) install_var("lflag", Int,(int *) & lflag, 0, 0, SETPCMENU);
  201.  
  202.     (void) install_command("strain", strain, BASEMENU,(int *) NULL);
  203.     (void) install_command("ptrain", ptrain, BASEMENU,(int *) NULL);
  204.     (void) install_command("tall", tall, BASEMENU,(int *) NULL);
  205.     (void) install_command("test", test_pattern, BASEMENU,(int *) NULL);
  206.     (void) install_command("reset",reset_weights,BASEMENU,(int *)NULL);
  207.     (void) install_command("newstart",newstart,BASEMENU,(int *)NULL);
  208.     (void) install_command("patterns", get_pattern_pairs, 
  209.                            GETMENU,(int *) NULL);
  210.     (void) install_command("unames", get_unames, GETMENU,(int *) NULL);
  211.     (void) install_var("nepochs", Int,(int *) & nepochs, 0, 0, SETPCMENU);
  212.     (void) install_command("lrate", change_lrate, SETPARAMMENU, (int *) NULL);
  213.     (void) install_var("lrate", Float,(int *) & lrate, 0, 0, NOMENU);
  214.     (void) install_var("ecrit", Float, (int *)& ecrit,0,0,SETPCMENU);
  215.     (void) install_var("noise", Float, (int *)&noise,0,0,SETPARAMMENU);
  216.     (void) install_var("linear", Int,(int *) &linear,0,0,SETMODEMENU);
  217.     (void) install_var("temp", Float, (int *)&temp,0,0,SETPARAMMENU);
  218.     (void) install_var("lt", Int,(int *) <,0,0,SETMODEMENU);
  219.     (void) install_var("cs", Int,(int *) &cs,0,0,SETMODEMENU);
  220.     (void) install_var("hebb", Int,(int *) &hebb,0,0,SETMODEMENU);
  221.     (void) install_var("epochno", Int,(int *) & epochno, 0, 0, SETSVMENU);
  222.     (void) install_var("patno", Int,(int *) & patno, 0, 0, SETSVMENU);
  223.     init_pattern_pairs();
  224.     (void) install_var("tss", Float,(int *) & tss, 0, 0, SETSVMENU);
  225.     (void) install_var("pss", Float,(int *) & pss, 0, 0, SETSVMENU);
  226.     (void) install_var("ndp", Float,(int *) & ndp, 0, 0, SETSVMENU);
  227.     (void) install_var("vcor", Float,(int *) & vcor, 0, 0, SETSVMENU);
  228.     (void) install_var("nvl", Float,(int *) & nvl, 0, 0, SETSVMENU);
  229.     init_weights();
  230. }
  231.  
  232. define_system() {
  233.     register int    i,j;
  234.  
  235.     if (!nunits) {
  236.     put_error("cannot init pa system, nunits not defined");
  237.     return(FALSE);
  238.     }
  239.     else
  240.     if (!noutputs) {
  241.         put_error("cannot init pa system, noutputs not defined");
  242.         return(FALSE);
  243.     }
  244.     else
  245.     if (!ninputs) {
  246.         put_error("cannot init pa system, ninputs not defined");
  247.         return(FALSE);
  248.     }
  249.     netinput = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  250.     (void) install_var("netinput", Vfloat,(int *) netinput,
  251.         nunits, 0, SETSVMENU);
  252.     for (i = 0; i < nunits; i++)
  253.     netinput[i] = 0.0;
  254.  
  255.     output = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  256.     (void) install_var("output", Vfloat,(int *) & output[ninputs],
  257.         noutputs, 0, SETSVMENU);
  258.     for (i = 0; i < nunits; i++)
  259.     output[i] = 0.0;
  260.  
  261.     error = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  262.     (void) install_var("error", Vfloat,(int *) & error[ninputs], 
  263.                 noutputs, 0, SETSVMENU);
  264.     for (i = 0; i < nunits; i++)
  265.     error[i] = 0.0;
  266.  
  267.     target = (float *) emalloc((unsigned)(sizeof(float) * noutputs));
  268.     (void) install_var("target", Vfloat,(int *) target, noutputs, 0,
  269.                SETSVMENU);
  270.     for (i = 0; i < noutputs; i++)
  271.     target[i] = 0.0;
  272.  
  273.     input = (float *) emalloc((unsigned)(sizeof(float) * ninputs));
  274.     (void) install_var("input", Vfloat,(int *) input, ninputs, 0, SETSVMENU);
  275.     
  276.     for (i = 0; i < ninputs; i++)
  277.     input[i] = 0.0;
  278.  
  279.     System_Defined = TRUE;
  280.     return(TRUE);
  281. }
  282.  
  283.  
  284. float  logistic (x)
  285. float  x;
  286. {
  287.     x /= temp;
  288.     if (x > 11.5129)
  289.     return(.99999);
  290.       else
  291.     if (x < -11.5129)
  292.         return(.00001);
  293.     else
  294.        return(1.0 / (1.0 + (float) exp( (double) ((-1.0) * x))));
  295. }
  296.  
  297. probability(val)
  298. float  val;
  299. {
  300.     return((rnd() < val) ? 1 : 0);
  301. }
  302.  
  303.  
  304. compute_output() {
  305.     register int    i,j,sender,num;
  306.  
  307.     for (i = ninputs; i < nunits; i++) {/* ranges over output units */
  308.     netinput[i] = bias[i];
  309.     sender = first_weight_to[i];
  310.     num = num_weights_to[i];
  311.     for (j = 0; j < num; j++) { /* ranges over input units */
  312.         netinput[i] += output[sender++]*weight[i][j];
  313.     }
  314.     if (linear) {
  315.       output[i] = netinput[i];
  316.     }
  317.     else if (lt) {
  318.       output[i] = (float) (netinput[i] > 0 ? 1.0 : 0.0 );
  319.     }
  320.     else if    (cs) {
  321.       output[i] =  logistic(netinput[i]);
  322.     }
  323.     else { /* default, stochastic mode */
  324.       output[i] = (float)probability((float)logistic(netinput[i]));
  325.     }
  326.     }
  327. }
  328.  
  329. compute_error() {
  330.     register int    i,j;
  331.  
  332.     for (i = ninputs, j = 0; i < nunits; j++, i++) {
  333.     error[i] = target[j] - output[i];
  334.     }
  335. }
  336.  
  337. change_weights() {
  338.     register int    i,j,ti,sender,num;
  339.  
  340.     if (hebb) {
  341.       for (i = ninputs,ti = 0; i < nunits; i++,ti++) {
  342.         output[i] = target[ti];
  343.     sender = first_weight_to[i];
  344.     num = num_weights_to[i];
  345.     for (j = 0; j < num; j++) {
  346.          weight[i][j] +=
  347.             epsilon[i][j]*output[i]*output[sender++];
  348.     }
  349.     bias[i] += bepsilon[i]*output[i];
  350.       }
  351.     }
  352.     else { /* delta rule, by default */
  353.       for (i = ninputs; i < nunits; i++) {
  354.     sender = first_weight_to[i];
  355.     num = num_weights_to[i];
  356.     for (j = 0; j < num; j++) {
  357.          weight[i][j] +=
  358.             epsilon[i][j]*error[i]*output[sender++];
  359.     }
  360.     bias[i] += bepsilon[i]*error[i];
  361.       }
  362.     }
  363. }
  364.  
  365. constrain_weights() {
  366. }
  367.  
  368. setinput() {
  369.     register int    i;
  370.  
  371.     for (i = 0; i < ninputs; i++) {
  372.         output[i] = input[i];
  373.     }
  374.     if (patno < 0) cpname[0] = '\0';
  375.     else strcpy(cpname,pname[patno]);
  376. }
  377.  
  378. trial() {
  379.     setinput();
  380.     compute_output();
  381.     compute_error();
  382.     sumstats();
  383. }
  384.  
  385. sumstats() {
  386.  
  387.     pss  =  (float) sumsquares(target,&output[ninputs],noutputs);
  388.     vcor =  (float) veccor(target,&output[ninputs],noutputs);
  389.     nvl  =  (float) veclen(&output[ninputs],noutputs);
  390.     ndp  =  (float) dotprod(target,&output[ninputs],noutputs);
  391.     tss += pss;
  392. }
  393.  
  394. ptrain() {
  395.   train('p');
  396. }
  397.  
  398. strain() {
  399.   train('s');
  400. }
  401.  
  402. train(c) char c; {
  403.     int     t,i,old,npat;
  404.     char    *str;
  405.  
  406.     if (!System_Defined)
  407.     if (!define_system())
  408.         return;
  409.  
  410.     for (t = 0; t < nepochs; t++) {
  411.     if (!tallflag) epochno++;
  412.     for (i = 0; i < npatterns; i++)
  413.         used[i] = i;
  414.     if (c == 'p') {
  415.       for (i = 0; i < npatterns; i++) {
  416.         npat = rnd() * (npatterns - i) + i;
  417.         old = used[i];
  418.         used[i] = used[npat];
  419.         used[npat] = old;
  420.       }
  421.     }
  422.     tss = 0.0;
  423.     for (i = 0; i < npatterns; i++) {
  424.         if (Interrupt) {
  425.         Interrupt_flag = 0;
  426.         update_display();
  427.         if (contin_test() == BREAK) return(BREAK);
  428.         }
  429.         patno = used[i];
  430.         distort(input,ipattern[patno],ninputs,noise);
  431.         distort(target,tpattern[patno],noutputs,noise);
  432.         trial();
  433.         /* the && lflag insures that we do not get a redundant
  434.            display update if change_weights is not going to be
  435.            called */
  436.         if (step_size == CYCLE && lflag) {
  437.         update_display();
  438.             if (single_flag) {
  439.            if (contin_test() == BREAK) return(BREAK);
  440.         }
  441.         }
  442.         if (lflag) change_weights();
  443.         if (step_size <= PATTERN) {
  444.           update_display();
  445.           if (single_flag) {
  446.         if (contin_test() == BREAK) return(BREAK);
  447.           }
  448.         }
  449.     }
  450.     if (step_size == EPOCH) {
  451.      update_display();
  452.      if (single_flag) {
  453.         if (contin_test() == BREAK) return(BREAK);
  454.      }
  455.         }
  456.     if (tss < ecrit)
  457.         break;
  458.     }
  459.     if (step_size == NEPOCHS) {
  460.     update_display();
  461.     }
  462.     return(CONTINUE);
  463. }
  464.  
  465. tall() {
  466.   int save_lflag;
  467.   int save_single_flag;
  468.   int save_nepochs;
  469.   int save_step_size;
  470.   
  471.   save_lflag = lflag;  lflag = 0;
  472.   save_single_flag = single_flag; 
  473.   if (in_stream == stdin) single_flag = 1;
  474.   save_nepochs = nepochs;  nepochs = 1;
  475.   save_step_size = step_size; if (step_size > PATTERN) step_size = PATTERN;
  476.   tallflag = 1;
  477.   train('s');
  478.   tallflag = 0;
  479.   lflag = save_lflag;
  480.   nepochs = save_nepochs;
  481.   single_flag = save_single_flag;
  482.   step_size = save_step_size;
  483. }
  484.   
  485. test_pattern() {
  486.     char   *str;
  487.     float *ivec, *tvec;
  488.     float tmp_noise;
  489.  
  490.     if(! System_Defined)
  491.       if(! define_system())
  492.        return(CONTINUE);
  493.  
  494.     str = get_command("input (#N, ?N, E for enter): ");
  495.     if (str == NULL) return(CONTINUE);
  496.     if(*str == '#' || *str == '?') {
  497.     if((patno = get_pattern_number(str+1)) < 0) {
  498.        return(put_error("Invalid pattern specification."));
  499.     }
  500.     tmp_noise = (float) (*str = '#' ? 0.0 : noise );
  501.         distort(input, ipattern[patno], ninputs, tmp_noise);
  502.     }
  503.     else {
  504.     patno = -1;
  505.     if ((ivec = readvec(" input ",ninputs)) == (float *) NULL) 
  506.         return(CONTINUE);
  507.         distort(input, ivec, ninputs, 0.0);
  508.     }
  509.     str = get_command("target (#N, ?N, E for enter): ");
  510.     if (str == NULL) {
  511.     tvec = readvec(" target ",noutputs);
  512.     }
  513.     else if(*str == '#' || *str == '?') {
  514.     if((patno = get_pattern_number(str+1)) < 0) {
  515.        return(put_error("Invalid pattern specification."));
  516.     }
  517.     tmp_noise = (float) (*str = '#' ? 0.0 : noise );
  518.         distort(target, tpattern[patno], noutputs, tmp_noise);
  519.     } 
  520.     else {
  521.     if ((tvec = readvec(" target ",noutputs)) == (float *) NULL) 
  522.         return(CONTINUE);
  523.         distort(target, tvec, noutputs, 0.0);
  524.     }
  525.     trial();
  526.     update_display();
  527.     return(CONTINUE);
  528. }
  529.  
  530. newstart() {
  531.     random_seed = rand();
  532.     reset_weights();
  533. }
  534.  
  535. reset_weights() {
  536.     register int    i,j,end;
  537.     
  538.     epochno = 0;
  539.     tss = 0.0;
  540.     pss = 0.0;
  541.     patno = 0;
  542.     ndp = vcor = nvl = 0.0;
  543.     cpname[0] = '\0';
  544.     
  545.     srand(random_seed);
  546.  
  547.     if (!System_Defined)
  548.     if (!define_system())
  549.         return;
  550.  
  551.     for (j = ninputs; j < nunits; j++) {
  552.     for (i = first_weight_to[j], end = i + num_weights_to[j];
  553.          i < end; i++) {
  554.         weight[j][i] = 0.0;
  555.     }
  556.     bias[j] = 0.0;
  557.     }
  558.     for (i = 0; i < ninputs; i++) {
  559.       input[i] = 0.0;
  560.     }
  561.     for (i = 0; i < noutputs; i++) {
  562.       target[i] = 0.0;
  563.     }
  564.     for (i = 0; i < nunits; i++) {
  565.       output[i] = error[i] = 0.0;
  566.     }
  567.     update_display();
  568. }
  569.  
  570. init_weights() {
  571.     (void) install_command("network", define_network, GETMENU,(int *) NULL);
  572.     (void) install_command("weights", read_weights, GETMENU,(int *) NULL);
  573.     (void) install_command("weights", write_weights, SAVEMENU,(int *) NULL);
  574.     (void) install_var("nunits", Int,(int *) & nunits, 0, 0, SETCONFMENU);
  575.     (void) install_var("ninputs", Int,(int *) & ninputs, 0, 0, SETCONFMENU);
  576.     (void) install_var("noutputs", Int,(int *) & noutputs, 0, 0, SETCONFMENU);
  577. }
  578.